﻿# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import json
import os
import zipfile
from glob import glob

import cv2
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from natsort import natsorted
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm

import config
import dnnlib
# Module-wide compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# pyspng is optional; the name is bound to None when the package is not installed.
try:
    import pyspng
except ImportError:
    pyspng = None

class EEGEncoder(nn.Module):
    """SimpleViT-based EEG feature extractor wrapped in ``DataParallel``.

    The network is always a ``SimpleViT`` with the fixed hyper-parameters
    below; it is wrapped in ``torch.nn.DataParallel`` *before* the checkpoint
    is loaded, so the checkpoint's ``'model_state_dict'`` entry is loaded into
    the wrapped module.

    Args:
        model: Unused; retained so existing call sites keep working.
        pretrained: When True (default) load ``checkpoint_path`` into the
            network. Previously this flag was ignored and the checkpoint was
            always loaded.
        trainable: Value assigned to ``requires_grad`` on every parameter.
            Previously ignored; the default (True) matches the old behaviour.
        checkpoint_path: File passed to ``torch.load``; must contain a
            ``'model_state_dict'`` entry.
    """

    def __init__(self, model=None, pretrained=True, trainable=True,
                 checkpoint_path='/home/ubuntu/chkpts/eeg_best_model.pth'):
        super().__init__()

        vit = SimpleViT(
            image_size=(128, 440),
            patch_size=(8, 20),
            num_classes=40,
            dim=256,
            depth=4,
            dim_head=16,
            heads=16,
            mlp_dim=16,
            channels=1,
        ).to(device)
        # Attribute name and DataParallel nesting are kept so that existing
        # checkpoints (keys under ``model.module.``) still load.
        self.model = torch.nn.DataParallel(vit).to(device)

        if pretrained:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            self.model.load_state_dict(checkpoint['model_state_dict'])

        for param in self.model.parameters():
            param.requires_grad = trainable

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)


class ImgEncoder(nn.Module):
    """ViT image encoder followed by a linear map from 768 to 256 features.

    Args:
        model_name: Identifier handed to ``ViTModel.from_pretrained`` when
            ``pretrained`` is true.
        pretrained: Load pretrained weights if true; otherwise build a
            ``ViTModel`` from a default ``ViTConfig``.
        trainable: Value assigned to ``requires_grad`` on every backbone
            parameter (the linear head is left untouched).
    """

    def __init__(self, model_name="google/vit-base-patch16-224", pretrained=True, trainable=True):
        super().__init__()
        # Created before the backbone so parameter registration order is unchanged.
        self.mlp_head = nn.Sequential(nn.Linear(768, 256))

        # NOTE(review): ViTModel / ViTConfig are not imported in the visible
        # part of this file - presumably from `transformers`; confirm.
        self.model = (
            ViTModel.from_pretrained(model_name)
            if pretrained
            else ViTModel(config=ViTConfig())
        )

        for weight in self.model.parameters():
            weight.requires_grad = trainable

    def forward(self, pixel_values):
        """Project the first token of the backbone's last hidden state."""
        hidden_states = self.model(pixel_values=pixel_values).last_hidden_state
        first_token = hidden_states[:, 0, :]
        return self.mlp_head(first_token)



class ProjectionHead(nn.Module):
    """Residual projection block: linear -> GELU -> linear -> dropout, skip, LayerNorm.

    Args:
        embedding_dim: Width of the incoming feature vectors.
        projection_dim: Width of the returned embeddings.
        dropout: Dropout probability applied before the residual addition.
    """

    def __init__(self, embedding_dim=256, projection_dim=128, dropout=.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Return the normalised sum of the projection and its refined branch."""
        skip = self.projection(x)
        branch = self.dropout(self.fc(self.gelu(skip)))
        return self.layer_norm(branch + skip)



class CLIPModel(nn.Module):
    """Contrastive EEG/image model with soft targets.

    Holds an ``EEGEncoder``, a ``DataParallel``-wrapped ``ImgEncoder`` and one
    ``ProjectionHead`` per modality. ``forward`` returns a scalar loss.

    Args:
        temperature: Divides the EEG-image logits and multiplies the averaged
            self-similarities used to build the soft targets.
        EEG_embedding: Input width of the EEG projection head.
        ImgEncoder_dim: Input width of the image projection head.
        image_encoder: Path to a checkpoint whose ``'model_state_dict'`` entry
            is loaded into the wrapped image encoder. When ``None`` the load
            is skipped (previously ``None`` crashed inside ``torch.load``).
        text_encoder: Unused; kept for signature compatibility.
        image_projection: Unused; kept for signature compatibility.
        text_projection: Unused; kept for signature compatibility.
    """

    def __init__(
        self,
        temperature=1,
        EEG_embedding=256,
        ImgEncoder_dim=256,
        image_encoder=None,
        text_encoder=None,
        image_projection=None,
        text_projection=None,
    ):
        super().__init__()
        self.eeg_encoder = EEGEncoder().to(device)

        # Wrap before loading: the checkpoint is loaded into the DataParallel
        # wrapper, exactly as in the original code.
        self.img_encoder = torch.nn.DataParallel(ImgEncoder().to(device)).to(device)
        if image_encoder is not None:
            img_checkpoint = torch.load(image_encoder, map_location=device)
            self.img_encoder.load_state_dict(img_checkpoint['model_state_dict'])

        self.eeg_projection = ProjectionHead(embedding_dim=EEG_embedding).to(device)
        self.img_projection = ProjectionHead(embedding_dim=ImgEncoder_dim).to(device)
        self.temperature = temperature

    def forward(self, batch):
        """Compute the mean symmetric soft-target cross-entropy for a batch.

        Args:
            batch: Mapping with an ``"eeg"`` entry fed to the EEG encoder and
                an ``"image"`` entry fed to the image encoder.

        Returns:
            Scalar tensor: the batch mean of the averaged EEG->image and
            image->EEG losses.
        """
        eeg_embeddings = self.eeg_projection(self.eeg_encoder(batch["eeg"]))
        image_embeddings = self.img_projection(self.img_encoder(batch["image"]))

        logits = (eeg_embeddings @ image_embeddings.T) / self.temperature

        # Soft targets come from the average of the two within-modality
        # similarity matrices rather than from the identity matrix.
        images_similarity = image_embeddings @ image_embeddings.T
        eeg_similarity = eeg_embeddings @ eeg_embeddings.T
        targets = F.softmax(
            (images_similarity + eeg_similarity) / 2 * self.temperature, dim=-1
        )

        eeg_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + eeg_loss) / 2.0  # one value per batch element
        return loss.mean()



def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and soft (probability) targets.

    Args:
        preds: Logits; log-softmax is taken over the last dimension.
        targets: Target weights with the same shape as ``preds``.
        reduction: ``"none"`` returns one loss per row, ``"mean"`` their
            average, ``"sum"`` their total.

    Returns:
        The per-row losses (summed over dimension 1) or their reduction.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
            (The original silently returned ``None`` in that case.)
    """
    loss = (-targets * torch.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(
        f"Unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'."
    )
    
    
    
# helpers

def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return t, t

def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    """Build a fixed 2-D sine/cosine positional embedding for a patch grid.

    Args:
        patches: Tensor of shape ``(batch, h, w, dim)``; only its shape,
            device and dtype are read.
        temperature: Base of the geometric frequency progression.
        dtype: Ignored - the result always takes ``patches.dtype``. The
            parameter is kept so existing call sites remain valid.

    Returns:
        Tensor of shape ``(h * w, dim)`` laid out per row as
        ``[sin(x), cos(x), sin(y), cos(y)]``, each part ``dim // 4`` wide,
        with rows ordered row-major over the grid (``y`` slow, ``x`` fast).
    """
    _, h, w, dim = patches.shape
    device, out_dtype = patches.device, patches.dtype

    # Explicit "ij" indexing keeps the historical layout and avoids the
    # warning torch emits when `indexing` is omitted.
    y, x = torch.meshgrid(
        torch.arange(h, device = device),
        torch.arange(w, device = device),
        indexing = "ij",
    )
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    n_freqs = dim // 4
    # max(..., 1) guards dim == 4, where the original divided 0 by 0 and
    # filled the whole embedding with NaN.
    omega = torch.arange(n_freqs, device = device) / max(n_freqs - 1, 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(out_dtype)

# patch dropout

class PatchDropout(nn.Module):
    """Randomly keep a subset of tokens per sample during training.

    Args:
        prob: Fraction of tokens to drop, ``0 <= prob < 1``.
    """

    def __init__(self, prob):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob

    def forward(self, x):
        """Return ``x`` with tokens dropped along dimension 1.

        In eval mode, or when ``prob`` is 0, ``x`` is returned untouched.
        Otherwise each sample keeps ``max(1, int(n * (1 - prob)))`` tokens,
        chosen independently per sample as the top-k of random scores.

        Args:
            x: Tensor of shape ``(batch, tokens, features)``.
        """
        if not self.training or self.prob == 0.:
            return x

        b, n, _ = x.shape
        num_patches_keep = max(1, int(n * (1 - self.prob)))

        # A (b, 1) column of batch indices broadcasts against the (b, keep)
        # token indices. This replaces an einops `rearrange` call that the
        # file never imported.
        batch_indices = torch.arange(b, device = x.device).unsqueeze(-1)
        patch_indices_keep = torch.randn(b, n, device = x.device).topk(num_patches_keep, dim = -1).indices

        return x[batch_indices, patch_indices_keep]

# classes

class FeedForward(nn.Module):
    """LayerNorm followed by a two-layer GELU MLP that maps ``dim`` back to ``dim``.

    Args:
        dim: Input and output feature width.
        hidden_dim: Width of the intermediate layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Order (and therefore the Sequential indices 0..3) matches the
        # original so state_dict keys are unchanged.
        stages = (
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the normalised MLP to ``x``."""
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Pre-norm multi-head self-attention without projection biases.

    Args:
        dim: Width of the input and output token embeddings.
        heads: Number of attention heads.
        dim_head: Width of each head; q, k and v are each
            ``heads * dim_head`` wide.
    """

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def _split_heads(self, t):
        """Reshape ``(b, n, heads * d)`` to ``(b, heads, n, d)``."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    def forward(self, x):
        """Attend over the token dimension of ``x`` (shape ``(b, n, dim)``).

        Returns a tensor of the same shape; the residual connection is the
        caller's responsibility.
        """
        x = self.norm(x)

        # Native reshape/transpose replaces einops `rearrange`, which this
        # file never imported. Head h occupies features [h*d, (h+1)*d).
        q, k, v = (self._split_heads(t) for t in self.to_qkv(x).chunk(3, dim = -1))

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        b, h, n, d = out.shape
        out = out.transpose(1, 2).reshape(b, n, h * d)
        return self.to_out(out)

class Transformer(nn.Module):
    """Stack of residual attention / feed-forward pairs with a final LayerNorm.

    Args:
        dim: Token embedding width.
        depth: Number of (Attention, FeedForward) pairs.
        heads: Attention heads per layer.
        dim_head: Width of each attention head.
        mlp_dim: Hidden width of each feed-forward block.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Built in the same order as before: attention, then feed-forward,
        # layer by layer.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Run ``x`` through every layer with residual additions, then normalise."""
        for block in self.layers:
            attention, feed_forward = block
            x = attention(x) + x
            x = feed_forward(x) + x
        return self.norm(x)

class _ToPatches(nn.Module):
    """Parameter-free layer: ``(b, c, H, W)`` -> ``(b, H/p1, W/p2, p1*p2*c)``.

    Within each patch vector the features are ordered patch-row, then
    patch-column, then channel (channel fastest). It replaces the einops
    pattern ``'b c (h p1) (w p2) -> b h w (p1 p2 c)'``.
    """

    def __init__(self, patch_height, patch_width):
        super().__init__()
        self.patch_height = patch_height
        self.patch_width = patch_width

    def forward(self, img):
        b, c, height, width = img.shape
        p1, p2 = self.patch_height, self.patch_width
        grid_h, grid_w = height // p1, width // p2
        x = img.reshape(b, c, grid_h, p1, grid_w, p2)
        # (b, c, h, p1, w, p2) -> (b, h, w, p1, p2, c)
        x = x.permute(0, 2, 4, 3, 5, 1)
        return x.reshape(b, grid_h, grid_w, p1 * p2 * c)


class SimpleViT(nn.Module):
    """Vision Transformer with fixed sin/cos position embedding, used as a feature extractor.

    ``forward`` returns the mean-pooled token embedding of width ``dim``;
    ``linear_head`` is constructed (so checkpoints that contain it load) but
    is not applied.

    Args:
        image_size: Input height/width, an int or an ``(h, w)`` tuple.
        patch_size: Patch height/width, an int or an ``(h, w)`` tuple; must
            divide ``image_size``.
        num_classes: Output width of the (unused in ``forward``) ``linear_head``.
        dim: Token embedding width.
        depth: Number of transformer layers.
        heads: Attention heads per layer.
        mlp_dim: Hidden width of the feed-forward blocks.
        channels: Number of input channels.
        dim_head: Width of each attention head.
        patch_dropout: Fraction of tokens dropped during training.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        # Index 0 is parameter-free, so the state_dict keys
        # (to_patch_embedding.1/.2/.3.*) are the same as with the einops
        # `Rearrange` layer, which this file never imported.
        self.to_patch_embedding = nn.Sequential(
            _ToPatches(patch_height, patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        self.patch_dropout = PatchDropout(patch_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """Embed ``img`` (shape ``(b, channels, H, W)``) into a ``(b, dim)`` tensor."""
        x = self.to_patch_embedding(img)
        pe = posemb_sincos_2d(x)
        # Collapse the patch grid into one token axis: (b, h, w, d) -> (b, h*w, d).
        x = x.reshape(x.shape[0], -1, x.shape[-1]) + pe

        x = self.patch_dropout(x)

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return x
    
    



class EEG2ImageDataset(Dataset):
    """Dataset of images paired with pre-computed EEG conditioning vectors.

    Everything happens in ``__init__``: four ``CLIPModel`` checkpoints are
    loaded, every EEG sample from the hard-coded test tensors is passed
    through each model's EEG encoder and projection head, and the four
    results are concatenated along the feature axis into ``self.eeg_feat``.
    Items are ``(image, conditioning_vector)`` pairs.

    Args:
        path: Stored as ``dataset_path``; the data itself is read from
            hard-coded file locations.
        resolution: Unused; ``self.resolution`` comes from ``config.image_height``.
        **super_kwargs: Printed and otherwise ignored.
    """

    def __init__(self, path, resolution=None, **super_kwargs):
        print(super_kwargs)
        self.dataset_path = path
        self.class_name = []
        self._raw_shape = [3, config.image_height, config.image_width]
        self.resolution = config.image_height
        self.has_labels  = True
        self.label_shape = [config.projection_dim]
        self.label_dim   = config.projection_dim
        self.name        = config.dataset_name
        self.image_shape = [3, config.image_height, config.image_width]
        self.num_channels = config.input_channel

        # Seeds the global torch / numpy generators (affects the whole process).
        seed = 45
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Standalone EEG ViT, kept available as `self.eeg_model`.
        self.eeg_model = SimpleViT(
            image_size = (128,440),
            patch_size = (8, 20),
            num_classes = 40,
            dim = 256,
            depth = 4,
            dim_head=16,
            heads = 16,
            mlp_dim = 16,
            channels = 1
        ).to(device)
        self.eeg_model = torch.nn.DataParallel(self.eeg_model).to(device)
        eeg_vit_state = torch.load('/home/ubuntu/chkpts/eeg_best_model.pth', map_location=device)
        self.eeg_model.load_state_dict(eeg_vit_state['model_state_dict'])

        # (full CLIPModel state dict, image-encoder checkpoint) per model.
        # The order fixes the layout of the concatenated feature vector.
        checkpoint_pairs = (
            ("/home/ubuntu/bestckpt/model_v1_2_vit_best_model_updated.pth", "/home/ubuntu/bestckpt/eegfeat_mid.pth"),
            ("/home/ubuntu/bestckpt/model_v2_2_vit_best_model_updated.pth", "/home/ubuntu/bestckpt/eegfeat_band.pth"),
            ("/home/ubuntu/bestckpt/model_v3_3_vit_best_model_updated.pth", "/home/ubuntu/bestckpt/eegfeat_high.pth"),
            ("/home/ubuntu/bestckpt/model_4_org_vit_best_model_updated.pth", "/home/ubuntu/bestckpt/eegfeat_all.pth"),
        )
        # Construct all four first, then load, matching the original order.
        clip_models = [
            CLIPModel(image_encoder=img_ckpt).to(device)
            for _, img_ckpt in checkpoint_pairs
        ]
        self.eeg_model.eval()
        for clip, (clip_ckpt, _) in zip(clip_models, checkpoint_pairs):
            clip.load_state_dict(torch.load(clip_ckpt, map_location=device))
            clip.eval()

        # NOTE(review): eeg_model is already DataParallel-wrapped above, so this
        # nests a second wrapper. Kept so the attribute's structure and its
        # placement on config.device are unchanged; confirm it is intended.
        self.eeg_model = torch.nn.DataParallel(self.eeg_model).to(config.device)

        self.eegs = torch.load("/home/ubuntu/test_eeg_data.pt")
        self.images = torch.load("/home/ubuntu/test_images.pt")
        self.images = (self.images).byte()
        self.labels = torch.load("/home/ubuntu/test_labels.pt")

        dataloader = DataLoader(
            TensorDataset(self.eegs, self.images), batch_size=16, shuffle=False
        )
        per_model_features = [[] for _ in clip_models]

        with torch.no_grad():
            for eegs_batch, _ in tqdm(dataloader, desc="Processing EEG Data"):
                # Swap the last two axes of each 3-D EEG batch before encoding.
                # NOTE(review): the batch stays 3-D here, while SimpleViT's patch
                # embedding is written for 4-D (b, c, H, W) input - confirm the
                # stored EEG tensors have the layout the encoder expects.
                eegs_batch = eegs_batch.to(device).permute(0, 2, 1)
                for collected, clip in zip(per_model_features, clip_models):
                    collected.append(clip.eeg_projection(clip.eeg_encoder(eegs_batch)))

        # Per model: stack batches along the sample axis; across models:
        # concatenate along the feature axis.
        self.eeg_feat = torch.cat(
            [torch.cat(collected, dim=0) for collected in per_model_features], dim=1
        )

        self.eegs     = torch.from_numpy(np.array(self.eegs)).to(torch.float32)
        self.images   = torch.from_numpy(np.array(self.images)).to(torch.float32)
        # The original read `self.eeg_feat_np`, which was never set and
        # raised AttributeError; the features are moved to CPU float32 directly.
        self.eeg_feat = self.eeg_feat.cpu().to(torch.float32)
        self.labels   = torch.from_numpy(np.array(self.labels)).to(torch.int32)

    def __len__(self):
        """Number of EEG samples."""
        return self.eegs.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, conditioning_vector)`` for sample ``idx``."""
        return self.images[idx], self.eeg_feat[idx]

    def get_label(self, idx):
        """Return the conditioning vector for sample ``idx``."""
        return self.eeg_feat[idx]
